import pandas as pd
import matplotlib.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os

from curlyBrace import curlyBrace


def main():
    """Print per-group summaries and draw the two bar charts.

    Reads ``recletteroutcomes.csv`` from this script's directory, prints the
    mean and standard error of ``answer`` per treatment group for the "wage"
    and "next" questions, then passes the means together with 1.96 * SE
    error bars to ``graph``. The two PNG files are named after this script
    with an ``a`` / ``b`` suffix.
    """
    # Load the data
    script_dir = os.path.dirname(__file__)
    df = pd.read_csv(os.path.join(script_dir, "recletteroutcomes.csv"))

    # Generate one 0/1 dummy column per treatment code
    treatment_codes = {
        "Male": "M",
        "Female_exogenous": "F3",
        "Female_endogenous": "F2",
    }
    for dummy_col, code in treatment_codes.items():
        df[dummy_col] = (df["treatment"] == code).astype(int)
    group_cols = list(treatment_codes)

    # question -> (bar heights, error-bar half-widths), filled in print order
    headers = {
        "wage": "Standard Errors for 'wage':",
        "next": "\nStandard Errors for 'next':",
    }
    stats = {}
    for question, header in headers.items():
        print(header)
        means, half_widths = [], []
        for group_col in group_cols:
            mean, error = compute_means_and_errors(df, question, group_col, 1)
            means.append(mean)
            half_widths.append(error*1.96)
            # means and SEs same as produced in Stata
            print(f"{group_col}: mean = {mean:.2f}, SE = {error:.2f}")
        stats[question] = (means, half_widths)

    means_wage, errors_wage = stats["wage"]
    means_next, errors_next = stats["next"]

    # Output files share this script's base name
    stem = os.path.splitext(os.path.basename(__file__))[0]

    # Create hiring likelihood graph
    graph(stem + 'a.png',
          means_next,
          errors_next,
          ["Men", "Counterfactual \n Women", "Women"],
          'maroon',
          "Hiring Likelihood",
          (7.5, max(means_next) + 0.4))

    # Create wage graph
    graph(stem + 'b.png',
          means_wage,
          errors_wage,
          ["Men", "Counterfactual \n Women", "Women"],
          'royalblue',
          "Prospective Wage",
          (20, max(means_wage) + 5))


# Define function to compute means and standard errors
def compute_means_and_errors(df, question, group_col, group_value):
    """Return the mean and standard error of ``answer`` for one group.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``question``, ``answer`` and ``group_col``.
    question : str
        Value of ``df['question']`` to keep.
    group_col : str
        Name of the column identifying the group.
    group_value :
        Value of ``df[group_col]`` to keep.

    Returns
    -------
    tuple
        ``(mean, standard_error)`` of the selected ``answer`` values, where
        the standard error is the sample standard deviation (pandas default,
        ``ddof=1``) divided by the square root of the number of non-missing
        answers.
    """
    # Filter the data based on the group and question
    selected = (df['question'] == question) & (df[group_col] == group_value)
    answers = df.loc[selected, 'answer']

    # mean() and std() skip missing answers, so the denominator must count
    # only the non-missing ones too; len() would also count NaN rows and
    # understate the standard error.
    n_valid = answers.count()

    mean = answers.mean()
    error = answers.std() / np.sqrt(n_valid)

    return mean, error


# Graph function
def graph(outfile, means, errors, categories, color, title, y_limit):
    """Draw the three-bar chart with its annotations and save it.

    Parameters
    ----------
    outfile : str
        Path passed on to ``save_plot``.
    means : sequence of float
        Bar heights; elements 0, 1 and 2 are read, in the order of
        ``categories``.
    errors : sequence of float
        Error-bar half-widths, passed to ``ax.bar`` as ``yerr``.
    categories : sequence of str
        Bar labels, also used as x tick labels.
    color : str
        Fill colour of the bars.
    title : str
        Accepted but not drawn anywhere on the figure.
    y_limit : tuple
        ``(bottom, top)`` limits of the y axis.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.bar(categories, means, yerr=errors, color=color, capsize=5, width=0.5)

    # Adding dashed lines for Systemic and Direct
    ax.hlines(y=means[0], xmin=0.25, xmax=2.25, colors='black', linestyles='dashed', label="Systemic")
    ax.hlines(y=means[1], xmin=1.25, xmax=2.25, colors='green', linestyles='dashed', label="Direct")

    # Vertical edges above the third bar, up to the highest mean
    total_line_y = max(means)
    ax.vlines(x=2.25, ymin=means[2], ymax=total_line_y, colors='black')
    ax.vlines(x=1.75, ymin=means[2], ymax=total_line_y, colors='black')

    # Each label sits midway between the two means it spans
    direct_text_y   = (means[0] + means[1]) / 2
    systemic_text_y = (means[1] + means[2]) / 2
    total_text_y    = (means[0] + means[2]) / 2

    # Adding annotations
    ax.text(2.3,  direct_text_y,   'Direct',   ha='left', va='center', fontsize=16, color='green',     fontweight='bold')
    ax.text(2.3,  systemic_text_y, 'Systemic', ha='left', va='center', fontsize=16, color='gold',      fontweight='bold')
    ax.text(1.45, total_text_y,    'Total',    ha='left', va='center', fontsize=16, fontweight='bold')

    # Add 'reverse' note 12 points below the 'Direct' label when the second
    # mean exceeds the first
    if means[0] < means[1]:
        offset    = transforms.ScaledTranslation(0, -12/72, ax.figure.dpi_scale_trans)
        transform = ax.transData + offset
        ax.text(2.3,
                direct_text_y,
                '(reverse)',
                ha         = 'left',
                va         = 'center',
                fontsize   = 12,
                style      = 'italic',
                color      = 'green',
                transform  = transform)

    # NOTE: The yellow and green areas and the brace next to 'Total' used to
    # be added to this graph outside of Python; they are now drawn below.

    # Adjust labels
    ax.set_xticks(range(len(categories)))
    ax.set_xticklabels(categories, fontsize=16)
    ax.tick_params(axis='y', labelsize=16)

    ax.set_ylim(y_limit)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # axvspan takes its vertical extent as a fraction of the axes height, so
    # data values are converted using the limits set just above.
    y_bottom, y_top = ax.get_ylim()
    axspan = y_top - y_bottom

    def axes_fraction(y):
        # Convert a data-space y value into a 0-1 axes fraction.
        return (y - y_bottom) / axspan

    # Fill differences
    ax.axvspan(1.75, 2.25,
               ymin  = axes_fraction(means[2]),
               ymax  = axes_fraction(means[1]),
               color = 'lemonchiffon')

    yfr = min(means[0], means[1])
    yto = max(means[0], means[1])
    ax.axvspan(1.75, 2.25,
               ymin      = axes_fraction(yfr),
               ymax      = axes_fraction(yto),
               facecolor = 'none',
               edgecolor = 'green',
               hatch     = '////',
               linewidth = 0)

    # Brace next to total; from https://github.com/iruletheworld/matplotlib-curly-brace
    # The end-point padding and the brace size argument depend on the y-axis
    # range, with separate values for ranges below 5 and for larger ranges.
    yr = y_limit[-1] - y_limit[0]
    brace_pad  = 0.0025 if yr < 5 else 0.025
    brace_size = 0.0425 if yr < 5 else 0.085
    total_text_x = 1.45 + 0.295
    curlyBrace(fig,
               ax,
               (total_text_x, means[2] - brace_pad),
               (total_text_x, means[0] - brace_pad),
               brace_size,
               color='black',
               lw=2)

    save_plot(outfile, ax, fontsize=16, transparent=False, fig=fig)

    # To preview interactively, call plt.show() here, before the figure is
    # closed. Closing releases the figure so repeated calls do not pile up
    # open figures in pyplot.
    plt.close(fig)


def save_plot(outfile, ax, fontsize=20, transparent=True, fig=None):
    """Set the tick-label font size on ``ax`` and save its figure.

    Parameters
    ----------
    outfile : str
        Destination path handed to ``savefig``.
    ax : matplotlib Axes
        Axes whose x and y tick labels are resized.
    fontsize : int, optional
        Font size applied to every tick label.
    transparent : bool, optional
        Forwarded to ``savefig``.
    fig : matplotlib Figure, optional
        Figure to save. When omitted, the figure that owns ``ax`` is saved,
        rather than whichever figure pyplot currently treats as active.
    """
    # Older Matplotlib code reached the label through ``tick.label`` on each
    # major tick; the tick-label getters below are used instead.
    for tick in ax.get_yticklabels():
        tick.set_fontsize(fontsize)

    for tick in ax.get_xticklabels():
        tick.set_fontsize(fontsize)

    target = ax.figure if fig is None else fig
    target.savefig(outfile, bbox_inches='tight', dpi=300, transparent=transparent)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
